package com.jhc.server;

import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.jhc.dto.BizCheckMessageDto;
import com.jhc.entity.CmsNotice;
import com.jhc.entity.CmsNoticeObject;
import com.jhc.entity.CmsUser;
import com.jhc.service.IBacNoticeObjectService;
import com.jhc.service.IBacNoticeService;
import com.jhc.service.ICmsTeacherService;
import com.jhc.service.ICmsUserService;
import com.jhc.utils.CommonResult;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CopyOnWriteArraySet;

/**
 * WebSocket endpoint at {@code /websocket/{staffNumber}} that pushes notice messages to users.
 * <p>
 * Each open connection is represented by an instance held in {@link #webSocketSet}. The services
 * are kept in static fields and filled in through the {@code @Autowired} setters on the
 * Spring-managed instance, so that the instances created by the WebSocket container can
 * reach them (NOTE(review): this relies on the container, not Spring, creating per-connection
 * instances — confirm against the endpoint configuration).
 *
 * @Author: zfm
 * @Date: 2019/12/4 09:02
 */
@ServerEndpoint(value = "/websocket/{staffNumber}", encoders = {WebSocketEncoder.class})
@Component
public class WebSocketServer {

    private static IBacNoticeObjectService iBacNoticeObjectService;
    private static ICmsTeacherService iCmsTeacherService;
    private static ICmsUserService iCmsUserService;
    private static IBacNoticeService iBacNoticeService;


    @Autowired
    public void setIBizNoticeService(IBacNoticeService iBacNoticeService) {
        WebSocketServer.iBacNoticeService = iBacNoticeService;
    }

    @Autowired
    public void setIBizNoticeObjectService(IBacNoticeObjectService iBacNoticeObjectService) {
        WebSocketServer.iBacNoticeObjectService = iBacNoticeObjectService;
    }

    @Autowired
    public void setICmsTeacherService(ICmsTeacherService iCmsTeacherService) {
        WebSocketServer.iCmsTeacherService = iCmsTeacherService;
    }

    /**
     * Injects the user service that {@link #onOpen} uses to validate the staff number.
     * Without this setter the static field stays {@code null} and every connection attempt
     * fails with a NullPointerException.
     */
    @Autowired
    public void setICmsUserService(ICmsUserService iCmsUserService) {
        WebSocketServer.iCmsUserService = iCmsUserService;
    }

    /**
     * Number of currently open connections. Only read/written through the synchronized
     * static helpers at the bottom of this class.
     */
    private static int onlineCount = 0;

    /**
     * Thread-safe set holding the endpoint instance of every open connection.
     */
    private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<WebSocketServer>();

    /**
     * The client session of this connection; all outgoing messages are sent through it.
     */
    private Session session;
    /**
     * Staff number taken from the connection path; used to address messages in {@link #sendInfo}.
     */
    private String staffNumber = "";

    /**
     * Called when a connection is established: registers it, validates the user, pushes the
     * user's unread notices and finally a "connected" message.
     *
     * @param session     the new client session
     * @param staffNumber staff number from the {@code {staffNumber}} path segment
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("staffNumber") String staffNumber) {
        this.session = session;
        // Assign before publishing this instance in the set so concurrent sendInfo calls
        // never match against the "" placeholder.
        this.staffNumber = staffNumber;

        webSocketSet.add(this);
        addOnlineCount();
        System.out.println("有新窗口开始监听:" + staffNumber + ",当前在线人数为" + getOnlineCount());

        // Reject unknown users with an error message (the connection itself stays registered).
        if (iCmsUserService.getOne(new QueryWrapper<CmsUser>().lambda()
                .eq(CmsUser::getNumber, staffNumber)) == null) {
            try {
                sendMessage(CommonResult.failed("无该用户"));
            } catch (IOException e) {
                e.printStackTrace();
            }
            return;
        }

        pushUnreadNotices();
        try {
            sendMessage(CommonResult.success("链接成功"));
        } catch (IOException e) {
            System.out.println("websocket IO异常");
        }
    }

    /**
     * Sends every unread notice addressed to this connection's staff number, merging the
     * parent notice and the per-user notice object into one DTO.
     */
    private void pushUnreadNotices() {
        List<CmsNoticeObject> cmsNoticeObjectList = iBacNoticeObjectService.list(new QueryWrapper<CmsNoticeObject>().lambda()
                .eq(CmsNoticeObject::getUserNumber, staffNumber).eq(CmsNoticeObject::getIsRead, false));
        for (CmsNoticeObject item : cmsNoticeObjectList) {
            CmsNotice cmsNotice = iBacNoticeService.getById(item.getNoticeId());
            if (cmsNotice == null) {
                // The parent notice no longer exists; BeanUtils.copyProperties rejects a null
                // source, which would otherwise abort all remaining pushes.
                continue;
            }
            BizCheckMessageDto bizCheckMessageDto = new BizCheckMessageDto();
            BeanUtils.copyProperties(cmsNotice, bizCheckMessageDto);
            BeanUtils.copyProperties(item, bizCheckMessageDto);
            // Expose the per-user (child) record's ID, not the parent notice ID.
            bizCheckMessageDto.setId(item.getId());
            try {
                sendMessage(CommonResult.success(bizCheckMessageDto));
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Called when the connection is closed: unregisters it and decrements the online count.
     */
    @OnClose
    public void onClose() {
        webSocketSet.remove(this);
        subOnlineCount();
        System.out.println("有一连接关闭！当前在线人数为" + getOnlineCount());
    }

    /**
     * Called when a client sends a text message; the message is broadcast to every open
     * connection (including the sender), wrapped in a success result.
     *
     * @param message the text sent by the client
     * @param session the sending session (unused)
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        System.out.println("收到来自窗口" + staffNumber + "的信息:" + message);
        for (WebSocketServer item : webSocketSet) {
            try {
                item.sendMessage(CommonResult.success(message));
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Called by the container when an error occurs on this connection; only logs it.
     *
     * @param session the affected session
     * @param error   the error raised
     */
    @OnError
    public void onError(Session session, Throwable error) {
        System.out.println("发生错误");
        error.printStackTrace();
    }

    /**
     * Pushes a result object to this connection's client, encoded by the endpoint's encoder.
     * <p>
     * Synchronized because {@code RemoteEndpoint.Basic} may throw {@code IllegalStateException}
     * when two threads send on the same session concurrently (e.g. a broadcast from
     * {@link #onMessage} overlapping {@link #sendInfo}).
     *
     * @param commonResult the payload to send
     * @throws IOException if the underlying send fails
     */
    private synchronized void sendMessage(CommonResult commonResult) throws IOException {
        try {
            this.session.getBasicRemote().sendObject(commonResult);
        } catch (EncodeException e) {
            e.printStackTrace();
        }
    }

    /**
     * Public entry point for other components: sends {@code commonResult} to every open
     * connection whose staff number equals {@code staffNumber}. Delivery is best effort: a
     * failing session is reported and skipped so the user's other sessions still receive it.
     *
     * @param commonResult the payload to send
     * @param staffNumber  the target staff number
     * @throws IOException never thrown in practice; kept so existing callers that catch it still compile
     */
    public static void sendInfo(CommonResult commonResult, String staffNumber) throws IOException {
        System.out.println("消息内容" + JSON.toJSONString(commonResult));
        for (WebSocketServer item : webSocketSet) {
            if (!item.staffNumber.equals(staffNumber)) {
                continue;
            }
            try {
                item.sendMessage(commonResult);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    private static synchronized int getOnlineCount() {
        return onlineCount;
    }

    private static synchronized void addOnlineCount() {
        WebSocketServer.onlineCount++;
    }

    private static synchronized void subOnlineCount() {
        WebSocketServer.onlineCount--;
    }
}
